Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse attention kernel for H100 (sm90) #20553

Merged
merged 1 commit into from
May 5, 2024

Conversation

tianleiwu
Copy link
Contributor

@tianleiwu tianleiwu commented May 3, 2024

Description

Follow up of #20216 to add sparse attention kernel compiled by Triton for H100 (sm90).

  • Refine sparse attention v1 kernel compilation (remove some combinations)
  • compile kernels for v1 kernels
  • compile kernels for H100
  • run performance tests

Performane

Test setting batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8

We compare sparse attention to corresponding GQA with local attention windows size 1024, or GQA with dense causal. Note that ORT-GQA-Dense has more computation than ORT-SparseAtt, while ORT-GQA-Local has less computation (no vertial strides) than ORT-SparseAtt. They are added for reference. It is not fair comparison, but could show the benefit of sparsity vs dense.

Example results in Azure Standard_ND96isr_H100_v5 VM with NVIDIA H100-80GB-HBM3 GPU (sm=90):

    prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0             16.0   0.079877       0.006362       0.006403       0.042758
    1             32.0   0.086920       0.016404       0.016686       0.044183
    2             64.0   0.090727       0.020429       0.020409       0.045343
    3            128.0   0.128148       0.032009       0.031984       0.051516
    4            256.0   0.323933       0.074110       0.073920       0.068308
    5            512.0   1.021856       0.162167       0.161951       0.109226
    6           1024.0   3.596002       0.452629       0.452780       0.231653
    7           2048.0  13.865088       1.499534       1.195749       0.515488
    8           4096.0   0.000000       5.454785       2.669682       1.163233
    9           8192.0   0.000000      22.068159       6.018604       2.772873

    token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0                  16.0   0.104460       0.012652       0.012661       0.069549
    1                  32.0   0.113866       0.012776       0.012765       0.069024
    2                  64.0   0.124600       0.016791       0.012672       0.069397
    3                 128.0   0.108658       0.017900       0.018294       0.074844
    4                 256.0   0.115463       0.029409       0.029608       0.078911
    5                 512.0   0.149824       0.033968       0.033701       0.092998
    6                1024.0   0.234050       0.042930       0.042951       0.116920
    7                2048.0   0.390695       0.061462       0.043008       0.121555
    8                4096.0   0.000000       0.097505       0.042948       0.134757
    9                8191.0   0.000000       0.165861       0.043542       0.158796

The following might be able to help performance on short sequence length. Need update operator spec:

  • Fall back to flash attention when total_sequence length < local_blocks * block_size

Motivation and Context

@tianleiwu tianleiwu marked this pull request as draft May 3, 2024 23:25
@tianleiwu tianleiwu force-pushed the tlwu/sparse_attention_sm90 branch from 603ce0e to ffe7c96 Compare May 4, 2024 01:16
@tianleiwu tianleiwu marked this pull request as ready for review May 4, 2024 01:18
@tianleiwu tianleiwu merged commit baaef59 into main May 5, 2024
94 of 95 checks passed
@tianleiwu tianleiwu deleted the tlwu/sparse_attention_sm90 branch May 5, 2024 02:53
@sophies927 sophies927 added the triage:approved Approved for cherrypicks for release label May 6, 2024
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024
### Description
Follow up of microsoft#20216 to add
sparse attention kernel compiled by Triton for H100 (sm90).
- [x] Refine sparse attention v1 kernel compilation (remove some
combinations)
- [x] compile kernels for v1 kernels
- [x] compile kernels for H100
- [x] run performance tests

### Performane

Test setting `batch_size=4, num_heads=32, max_seq_len=8192,
head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8,
num_layout=8`

We compare sparse attention to corresponding GQA with local attention
windows size 1024, or GQA with dense causal. Note that ORT-GQA-Dense has
more computation than ORT-SparseAtt, while ORT-GQA-Local has less
computation (no vertial strides) than ORT-SparseAtt. They are added for
reference. It is not fair comparison, but could show the benefit of
sparsity vs dense.

Example results in Azure Standard_ND96isr_H100_v5 VM with NVIDIA
H100-80GB-HBM3 GPU (sm=90):
```
    prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0             16.0   0.079877       0.006362       0.006403       0.042758
    1             32.0   0.086920       0.016404       0.016686       0.044183
    2             64.0   0.090727       0.020429       0.020409       0.045343
    3            128.0   0.128148       0.032009       0.031984       0.051516
    4            256.0   0.323933       0.074110       0.073920       0.068308
    5            512.0   1.021856       0.162167       0.161951       0.109226
    6           1024.0   3.596002       0.452629       0.452780       0.231653
    7           2048.0  13.865088       1.499534       1.195749       0.515488
    8           4096.0   0.000000       5.454785       2.669682       1.163233
    9           8192.0   0.000000      22.068159       6.018604       2.772873

    token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0                  16.0   0.104460       0.012652       0.012661       0.069549
    1                  32.0   0.113866       0.012776       0.012765       0.069024
    2                  64.0   0.124600       0.016791       0.012672       0.069397
    3                 128.0   0.108658       0.017900       0.018294       0.074844
    4                 256.0   0.115463       0.029409       0.029608       0.078911
    5                 512.0   0.149824       0.033968       0.033701       0.092998
    6                1024.0   0.234050       0.042930       0.042951       0.116920
    7                2048.0   0.390695       0.061462       0.043008       0.121555
    8                4096.0   0.000000       0.097505       0.042948       0.134757
    9                8191.0   0.000000       0.165861       0.043542       0.158796
```
The following might be able to help performance on short sequence
length. Need update operator spec:
 Fall back to flash attention when total_sequence length < local_blocks * block_size

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
@yihonglyu yihonglyu added the cherry-picked Cherry-picked for a cherrypicks branch label May 9, 2024
yihonglyu pushed a commit that referenced this pull request May 9, 2024
### Description
Follow up of #20216 to add
sparse attention kernel compiled by Triton for H100 (sm90).
- [x] Refine sparse attention v1 kernel compilation (remove some
combinations)
- [x] compile kernels for v1 kernels
- [x] compile kernels for H100
- [x] run performance tests

### Performane

Test setting `batch_size=4, num_heads=32, max_seq_len=8192,
head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8,
num_layout=8`

We compare sparse attention to corresponding GQA with local attention
windows size 1024, or GQA with dense causal. Note that ORT-GQA-Dense has
more computation than ORT-SparseAtt, while ORT-GQA-Local has less
computation (no vertial strides) than ORT-SparseAtt. They are added for
reference. It is not fair comparison, but could show the benefit of
sparsity vs dense.

Example results in Azure Standard_ND96isr_H100_v5 VM with NVIDIA
H100-80GB-HBM3 GPU (sm=90):
```
    prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0             16.0   0.079877       0.006362       0.006403       0.042758
    1             32.0   0.086920       0.016404       0.016686       0.044183
    2             64.0   0.090727       0.020429       0.020409       0.045343
    3            128.0   0.128148       0.032009       0.031984       0.051516
    4            256.0   0.323933       0.074110       0.073920       0.068308
    5            512.0   1.021856       0.162167       0.161951       0.109226
    6           1024.0   3.596002       0.452629       0.452780       0.231653
    7           2048.0  13.865088       1.499534       1.195749       0.515488
    8           4096.0   0.000000       5.454785       2.669682       1.163233
    9           8192.0   0.000000      22.068159       6.018604       2.772873

    token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0                  16.0   0.104460       0.012652       0.012661       0.069549
    1                  32.0   0.113866       0.012776       0.012765       0.069024
    2                  64.0   0.124600       0.016791       0.012672       0.069397
    3                 128.0   0.108658       0.017900       0.018294       0.074844
    4                 256.0   0.115463       0.029409       0.029608       0.078911
    5                 512.0   0.149824       0.033968       0.033701       0.092998
    6                1024.0   0.234050       0.042930       0.042951       0.116920
    7                2048.0   0.390695       0.061462       0.043008       0.121555
    8                4096.0   0.000000       0.097505       0.042948       0.134757
    9                8191.0   0.000000       0.165861       0.043542       0.158796
```
The following might be able to help performance on short sequence
length. Need update operator spec:
 Fall back to flash attention when total_sequence length < local_blocks * block_size

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
@yihonglyu yihonglyu added the rel-merged Cherrypicks merged into release label May 10, 2024
poweiw pushed a commit to poweiw/onnxruntime that referenced this pull request Jun 25, 2024
### Description
Follow up of microsoft#20216 to add
sparse attention kernel compiled by Triton for H100 (sm90).
- [x] Refine sparse attention v1 kernel compilation (remove some
combinations)
- [x] compile kernels for v1 kernels
- [x] compile kernels for H100
- [x] run performance tests

### Performane

Test setting `batch_size=4, num_heads=32, max_seq_len=8192,
head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8,
num_layout=8`

We compare sparse attention to corresponding GQA with local attention
windows size 1024, or GQA with dense causal. Note that ORT-GQA-Dense has
more computation than ORT-SparseAtt, while ORT-GQA-Local has less
computation (no vertial strides) than ORT-SparseAtt. They are added for
reference. It is not fair comparison, but could show the benefit of
sparsity vs dense.

Example results in Azure Standard_ND96isr_H100_v5 VM with NVIDIA
H100-80GB-HBM3 GPU (sm=90):
```
    prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0             16.0   0.079877       0.006362       0.006403       0.042758
    1             32.0   0.086920       0.016404       0.016686       0.044183
    2             64.0   0.090727       0.020429       0.020409       0.045343
    3            128.0   0.128148       0.032009       0.031984       0.051516
    4            256.0   0.323933       0.074110       0.073920       0.068308
    5            512.0   1.021856       0.162167       0.161951       0.109226
    6           1024.0   3.596002       0.452629       0.452780       0.231653
    7           2048.0  13.865088       1.499534       1.195749       0.515488
    8           4096.0   0.000000       5.454785       2.669682       1.163233
    9           8192.0   0.000000      22.068159       6.018604       2.772873

    token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0                  16.0   0.104460       0.012652       0.012661       0.069549
    1                  32.0   0.113866       0.012776       0.012765       0.069024
    2                  64.0   0.124600       0.016791       0.012672       0.069397
    3                 128.0   0.108658       0.017900       0.018294       0.074844
    4                 256.0   0.115463       0.029409       0.029608       0.078911
    5                 512.0   0.149824       0.033968       0.033701       0.092998
    6                1024.0   0.234050       0.042930       0.042951       0.116920
    7                2048.0   0.390695       0.061462       0.043008       0.121555
    8                4096.0   0.000000       0.097505       0.042948       0.134757
    9                8191.0   0.000000       0.165861       0.043542       0.158796
```
The following might be able to help performance on short sequence
length. Need update operator spec:
 Fall back to flash attention when total_sequence length < local_blocks * block_size

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cherry-picked Cherry-picked for a cherrypicks branch rel-merged Cherrypicks merged into release release:1.18.0 triage:approved Approved for cherrypicks for release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants